{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "08bf4b89015df3fa62ec3e5cebfe3c3e5a181176f7e37ebd57af46c3a62ed99e"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "使用ResNet改良版的“批量归一化、激活和卷积”结构"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def conv_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "        nn.BatchNorm2d(in_channels),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)\n",
    "    )\n",
    "\n",
    "    return blk\n",
    "\n",
    "\n",
    "class DenseBlock(nn.Module):\n",
    "    def __init__(self, num_convs, in_channels, out_channels):\n",
    "        super(DenseBlock, self).__init__()\n",
    "        net = []\n",
    "        for i in range(num_convs):\n",
    "            in_c = in_channels + i * out_channels\n",
    "            net.append(conv_block(in_c, out_channels))\n",
    "        \n",
    "        self.net = nn.ModuleList(net)\n",
    "        self.out_channels = in_channels + num_convs * out_channels\n",
    "    \n",
    "    def forward(self, X):\n",
    "        for blk in self.net:\n",
    "            Y = blk(X)\n",
    "            X = torch.cat((X, Y), dim=1)\n",
    "        \n",
    "        return X\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "torch.Size([4, 23, 8, 8])"
      ]
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "source": [
    "blk = DenseBlock(2, 3, 10)\n",
    "X = torch.rand(4, 3, 8, 8)\n",
    "Y = blk(X)\n",
    "Y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transition_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "        nn.BatchNorm2d(in_channels),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(in_channels, out_channels, kernel_size=1),\n",
    "        nn.AvgPool2d(kernel_size=2, stride=2)\n",
    "    )\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "    nn.BatchNorm2d(64),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "0 output shape:\t torch.Size([1, 64, 48, 48])\n1 output shape:\t torch.Size([1, 64, 48, 48])\n2 output shape:\t torch.Size([1, 64, 48, 48])\n3 output shape:\t torch.Size([1, 64, 24, 24])\nDenseBlock_0 output shape:\t torch.Size([1, 192, 24, 24])\ntransition_block_0 output shape:\t torch.Size([1, 96, 12, 12])\nDenseBlock_1 output shape:\t torch.Size([1, 224, 12, 12])\ntransition_block_1 output shape:\t torch.Size([1, 112, 6, 6])\nDenseBlock_2 output shape:\t torch.Size([1, 240, 6, 6])\ntransition_block_2 output shape:\t torch.Size([1, 120, 3, 3])\nDenseBlock_3 output shape:\t torch.Size([1, 248, 3, 3])\nBN output shape:\t torch.Size([1, 248, 3, 3])\nrelu output shape:\t torch.Size([1, 248, 3, 3])\nglobal_avg_pool output shape:\t torch.Size([1, 248, 1, 1])\nfc output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "num_channels, growth_rate = 64, 32 # num_channels为当前的通道数\n",
    "num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
    "\n",
    "for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    DB = DenseBlock(num_convs, num_channels, growth_rate)\n",
    "    net.add_module(\"DenseBlock_%d\" % i, DB)\n",
    "\n",
    "    num_channels = DB.out_channels\n",
    "\n",
    "    if i != len(num_convs_in_dense_blocks) - 1:\n",
    "        net.add_module(\"transition_block_%d\" % i, transition_block(num_channels, num_channels // 2))\n",
    "        num_channels = num_channels // 2\n",
    "\n",
    "\n",
    "net.add_module(\"BN\", nn.BatchNorm2d(num_channels))\n",
    "net.add_module(\"relu\", nn.ReLU())\n",
    "net.add_module(\"global_avg_pool\", d2l.GlobalAvgPool2d())\n",
    "net.add_module(\"fc\", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10)))\n",
    "\n",
    "\n",
    "X = torch.rand((1, 1, 96, 96))\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, 'output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.4603, train acc 0.837, test acc 0.848, time 76.2 sec\n",
      "epoch 2, loss 0.2746, train acc 0.898, test acc 0.891, time 75.8 sec\n",
      "epoch 3, loss 0.2320, train acc 0.915, test acc 0.858, time 76.5 sec\n",
      "epoch 4, loss 0.2104, train acc 0.922, test acc 0.894, time 75.7 sec\n",
      "epoch 5, loss 0.1905, train acc 0.930, test acc 0.904, time 75.4 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}